# Data wrangling and report rendering
library(tidyverse)
library(knitr) ; library(kableExtra)
# Plotting
library(plotly) ; library(viridis) ; library(gridExtra) ; library(RColorBrewer) ; library(ggpubr)
# Modelling: GAMs (bias correction), classification, performance metrics
library(mgcv)
library(caret) ; library(ROCR) ; library(car) ; library(MLmetrics)
# Gene annotation and contingency tables
library(biomaRt)
library(expss)

# Return the project's SFARI colour palette entries at positions r.
# Usage elsewhere maps positions 1-3 to SFARI scores 1-3 and the grey
# entries (7, 8) to 'Others'/'Neuronal' -- TODO confirm full mapping.
# The original assigned the result to a throwaway variable, which made the
# return value invisible; returning the expression directly is equivalent
# for callers and visible when called interactively.
SFARI_colour_hue = function(r) {
  c('#FF7631','#FFB100','#E8E328','#8CC83F','#62CCA6','#59B9C9','#b3b3b3','#808080','gray','#d9d9d9')[r]
}
# Gandal dataset
# Loads (at least) datExpr, datGenes, datMeta, DE_info, dds and
# rownames_dataset, all referenced later in this script -- TODO confirm
# the full list of objects in the RData file
load('./../Data/preprocessed_data.RData')
# Expression matrix: genes in rows (named by Ensembl gene ID), samples in columns
datExpr = datExpr %>% data.frame
rownames(datExpr) = datGenes$ensembl_gene_id
# Differential expression results as a plain data.frame
DE_info = DE_info %>% data.frame
# Use each sample's title as its ID
datMeta = datMeta %>% mutate(ID = title)


# Ridge Regression output
# Loads the objects produced by the Ridge model run; `predictions` and
# `dataset` are used below -- TODO confirm full contents
load('./../Data/Ridge_model.RData')

# SFARI Genes
SFARI_genes = read_csv('./../../../SFARI/Data/SFARI_genes_01-03-2020_w_ensembl_IDs.csv')
# Keep a single row per gene and drop entries without an Ensembl ID match
SFARI_genes = SFARI_genes[!duplicated(SFARI_genes$ID) & !is.na(SFARI_genes$ID),]


# GO Neuronal annotations: regex 'neuron' in GO functional annotations and label the genes that make a match as neuronal
# Result: one row per matching gene with Neuronal = 1, used as a flag in
# later joins
GO_annotations = read.csv('./../Data/genes_GO_annotations.csv')
GO_neuronal = GO_annotations %>% filter(grepl('neuron', go_term)) %>% 
              mutate('ID'=as.character(ensembl_gene_id)) %>% 
              dplyr::select(-ensembl_gene_id) %>% distinct(ID) %>%
              mutate('Neuronal'=1)

# Add all this info to predictions
# gene.score: the SFARI score when the gene has one, otherwise 'Neuronal'
# (gene has a neuron-related GO annotation) or 'Others'
predictions = predictions %>% left_join(SFARI_genes %>% dplyr::select(ID, `gene-score`), by = 'ID') %>%
              mutate(gene.score = ifelse(is.na(`gene-score`), 
                                         ifelse(ID %in% GO_neuronal$ID, 'Neuronal', 'Others'), 
                                         `gene-score`)) %>%
              dplyr::select(-`gene-score`)

# Module assignment per gene, taken from the selected clustering algorithm
clustering_selected = 'DynamicHybrid'
clusterings = read_csv('./../Data/clusters.csv')
# Extract the selected column as a plain vector; equivalent to (and simpler
# than) the data.frame %>% unlist %>% unname chain on the one-column subset
clusterings$Module = clusterings[[clustering_selected]]
assigned_module = clusterings %>% dplyr::select(ID, Module)

# Add gene symbol
getinfo = c('ensembl_gene_id','external_gene_id')
mart = useMart(biomart='ENSEMBL_MART_ENSEMBL', dataset='hsapiens_gene_ensembl',
               host='feb2014.archive.ensembl.org') ## Gencode v19
gene_names = getBM(attributes=getinfo, filters=c('ensembl_gene_id'), values=rownames(datExpr), mart=mart)


rm(rownames_dataset, GO_annotations, datGenes, dds, clustering_selected,
   clusterings)


Introduction


In 10_classification_model.html we trained a Ridge regression to assign a probability to each gene with the objective of identifying new candidate SFARI Genes based on their gene expression behaviour captured with the WGCNA pipeline

The model seems to perform well (performance metrics can be found in 10_classification_model.html), but we found a bias related to the level of expression of the genes, in general, with the model assigning higher probabilities to genes with higher levels of expression

This is a problem because we had previously discovered a bias in the SFARI scores related to mean level of expression, which means that this could be a confounding factor in our model and the reason why it seems to perform well

# Per-gene mean and standard deviation of expression
mean_and_sd = data.frame(ID=rownames(datExpr), meanExpr=rowMeans(datExpr), sdExpr=apply(datExpr,1,sd))

plot_data = predictions %>% left_join(mean_and_sd, by='ID')

# Visualise the bias: model probability rises with mean expression
# NOTE(review): 'expresion' in the title is a typo that shows in the plot
plot_data %>% ggplot(aes(meanExpr, prob)) + geom_point(alpha=0.1, color='#0099cc') + 
              geom_smooth(method='loess', color='gray', alpha=0.2) +
              xlab('Mean Expression') + ylab('Probability') + 
              ggtitle('Bias in model probabilities by level of expresion') +
              theme_minimal()

rm(mean_and_sd)



Solutions to Bias Problem


This section is based on the paper Identifying and Correcting Label Bias in Machine Learning


Work in fair classification can be categorised into three approaches:



1. Post-processing Approach


After the model has been trained with the bias, perform a post-processing of the classifier outputs. This approach is quite simple to implement but has some downsides:

  • It has limited flexibility

  • Decoupling the training and calibration can lead to models with poor accuracy tradeoff (when training your model it may be focusing on the bias, in our case mean expression, and overlooking more important aspects of your data, such as biological significance)

Note: This is the approach we are going to try in this Markdown



2. Lagrangian Approach


Transforming the problem into a constrained optimisation problem (fairness as the constraint) using Lagrange multipliers.

Some of the downsides of this approach are:

  • The fairness constraints are often irregular and have to be relaxed in order to optimise

  • Training can be difficult, the Lagrangian may not even have a solution to converge to

  • Constrained optimisation can be inherently unstable

  • It can overfit and have poor fairness generalisation

  • According to the paper, it often yields poor trade-offs in fairness and accuracy

Note: It seems quite complicated and has many downsides, so I’m not going to implement this approach



3. Pre-processing Approach


These approaches primarily involve “massaging” the data to remove bias.

Some downsides are:

  • These approaches typically do not perform as well as the state-of-the-art and come with few theoretical guarantees

Note: In earlier versions of this code, I implemented this approach by trying to remove the level of expression signal from each feature of the dataset (since the Module Membership features capture the bias in an indirect way), but removing the mean expression signal modified the module membership of the genes in big ways sometimes and it didn’t seem to solve the problem in the end, so this proved not to be very useful and wasn’t implemented in this final version



New Method proposed by the paper (weighting technique)


They introduce a new mathematical framework for fairness in which we assume that there exists an unknown but unbiased ground truth label function and that the labels observed in the data are assigned by an agent who is possibly biased, but otherwise has the intention of being accurate

Assigning appropriate weights to each sample in the training data and iteratively training a classifier with the new weighted samples leads to an unbiased classifier on the original un-weighted dataset that simultaneously minimises the weighted loss and maximises fairness

Advantages:

  • This approach works also on settings where both the features and the labels are biased

  • It can be used with many ML algorithms

  • It can be applied to many notions of fairness

  • It doesn’t have strict assumptions about the behaviour of the data or the labels

  • According to the paper, it’s fast and robust

  • According to the paper, it consistently leads to fairer classifiers, as well as a better or comparative predictive error than the other methods


Also, this is not important, but I thought it was interesting: since the algorithm simultaneously minimises the weighted loss and maximises fairness via learning the coefficients, it may be interpreted as competing goals with different objective functions; thus, it’s a form of a non-zero-sum two-player game

Note: Implemented in 14_bias_correciton_weighting_technique.html






Post Processing Approach Implementation



After the model has been trained with the bias, perform a post-processing of the classifier outputs

Since the effect of the bias is proportional to the mean level of expression of a gene, we can correct it by removing the effect of the mean expression from the probability of the model

Problems:


Remove Bias


The relation between level of expression and probability assigned by the model could be modelled by a linear regression, but we would lose some of the behaviour. Fitting a curve using Generalised Additive Models seems to capture the relation in a much better way, with an \(R^2\) twice as large and no recognisable pattern in the residuals of the regression

# Keep a copy of the uncorrected predictions before any post-processing
test_set = predictions
old_predictions = predictions

# Attach each gene's mean expression to its predicted probability
plot_data = data.frame('ID'=rownames(datExpr), 'meanExpr'=rowMeans(datExpr)) %>% 
            right_join(test_set, by='ID')

# Fit linear and GAM models to data
lm_fit = lm(prob ~ meanExpr, data = plot_data)
gam_fit = gam(prob ~ s(meanExpr), method = 'REML', data = plot_data)

# Residuals of each fit: what remains of the probability once the
# mean-expression effect is removed
plot_data = plot_data %>% mutate(lm_res = lm_fit$residuals, gam_res = gam_fit$residuals)

# Plot data: the two fits on top, their residuals (with R^2) below
p1 = plot_data %>% ggplot(aes(meanExpr, prob)) + geom_point(alpha=0.1, color='#0099cc') + geom_smooth(method='lm', color='gray', alpha=0.3) +
     xlab('Mean Expression') + ylab('Probability') + ggtitle('Linear Fit') + theme_minimal()

p2 = plot_data %>% ggplot(aes(meanExpr, prob)) + geom_point(alpha=0.1, color='#0099cc') + geom_smooth(method='gam', color='gray', alpha=0.3) +
     xlab('Mean Expression') + ylab('Probability') + ggtitle('GAM fit') + theme_minimal()

p3 = plot_data %>% ggplot(aes(meanExpr, lm_res)) + geom_point(alpha=0.1, color='#ff9900') + 
     geom_smooth(method='gam', color='gray', alpha=0.3) + xlab('Mean Expression') +
     ylab('Residuals') + theme_minimal() + ggtitle(bquote(paste(R^{2},' = ', .(round(summary(lm_fit)$r.squared, 4)))))

p4 = plot_data %>% ggplot(aes(meanExpr, gam_res)) + geom_point(alpha=0.1, color='#ff9900') + 
     geom_smooth(method='gam', color='gray', alpha=0.3) + xlab('Mean Expression') +
     ylab('Residuals') + theme_minimal() + ggtitle(bquote(paste(R^{2},' = ', .(round(summary(gam_fit)$r.sq, 4)))))

grid.arrange(p1, p2, p3, p4, nrow = 2)

# gam_fit is deliberately kept: the next chunk uses it to build the corrected score
rm(p1, p2, p3, p4, lm_fit)


Remove bias from scores with GAM fit


  • Assigning the residuals of the GAM model as the new model probability

  • Adding the mean probability of the original model to each new probability so our new probabilities have the same mean as the original ones

  • As with the plot above, the relation between mean expression and the probability assigned by the model is gone

# Correct Bias
# Corrected score = GAM residuals re-centred at the original mean probability
test_set$corrected_score = gam_fit$residuals + mean(test_set$prob)

# Plot results
plot_data = data.frame('ID'=rownames(datExpr), 'meanExpr'=rowMeans(datExpr)) %>% 
            right_join(test_set, by='ID')

# After the correction the trend against mean expression should be flat
plot_data %>% ggplot(aes(meanExpr, corrected_score)) + geom_point(alpha=0.1, color='#0099cc') + 
              geom_smooth(method='gam', color='gray', alpha=0.3) + ylab('Corrected Score') + xlab('Mean Expression') +
              theme_minimal() + ggtitle('Mean expression vs Model score corrected using GAM')

rm(gam_fit)


We could use these corrected scores directly to study the performance of the bias-corrected model, but we wouldn’t have the standard deviation of the performance metrics as we had in 10_classification_model.html where we ran the model several times. To have them here as well, I’m going to run the model many times, correcting the bias in each run


Ridge Regression with Post Processing Bias Correction


Notes:

  • Running the model multiple times to get more accurate measurements of its performance

  • Over-sampling positive samples in the training set to obtain a 1:1 class ratio using SMOTE

  • Performing 10 repetitions of cross validation with 10-folds each

  • Correcting the mean expression bias in each run using the preprocessing approach

### DEFINE FUNCTIONS

# Split the global `dataset` into train/test partitions, stratified by SFARI
# gene score so every score category is proportionally represented in both.
# p: proportion of samples assigned to the training set; seed: RNG seed.
# Depends on the globals `dataset` and `original_dataset` and on
# caret::createDataPartition. Returns list(train_set, test_set).
create_train_test_sets = function(p, seed){

  # Recover each sample's SFARI score ('None' when absent) to stratify on it
  stratify_by = dataset %>% mutate(ID = rownames(.)) %>% dplyr::select(ID) %>%
                left_join(original_dataset %>% dplyr::select(ID, gene.score), by = 'ID') %>% 
                mutate(gene.score = ifelse(is.na(gene.score), 'None', gene.score))

  set.seed(seed)
  in_train = createDataPartition(stratify_by$gene.score, p = p, list = FALSE)

  # Rows selected by the partition go to training, the rest to testing
  list('train_set' = dataset[in_train,], 'test_set' = dataset[-in_train,])
}



# Train a Ridge classifier on one train/test split, correct the
# mean-expression bias in the test-set probabilities with a GAM, and measure
# performance on the bias-corrected predictions.
# p: training proportion; seed: RNG seed for the split, CV folds and SMOTE.
# Depends on the globals `dataset`, `original_dataset` and `datExpr`.
# Returns a list with acc/prec/rec/F1/AUC, the per-gene predictions and the
# model coefficients.
run_model = function(p, seed){
  
  # Create train and test sets
  train_test_sets = create_train_test_sets(p, seed)
  train_set = train_test_sets[['train_set']]
  test_set = train_test_sets[['test_set']]
  
  # Train Model: Ridge (glmnet, alpha = 0) tuned over a lambda grid with
  # repeated CV, SMOTE over-sampling the minority (SFARI) class
  train_set = train_set %>% mutate(SFARI = ifelse(SFARI==TRUE, 'SFARI', 'not_SFARI') %>% as.factor)
  lambda_seq = 10^seq(1, -4, by = -.1)
  set.seed(seed)
  k_fold = 10
  cv_repeats = 5
  smote_over_sampling = trainControl(method = 'repeatedcv', number = k_fold, repeats = cv_repeats,
                                     verboseIter = FALSE, classProbs = TRUE, savePredictions = 'final', 
                                     summaryFunction = twoClassSummary, sampling = 'smote')
  fit = train(SFARI ~., data = train_set, method = 'glmnet', trControl = smote_over_sampling, metric = 'ROC',
              tuneGrid = expand.grid(alpha = 0, lambda = lambda_seq))
  
  # Predict labels in test set
  predictions = fit %>% predict(test_set, type = 'prob')
  preds = data.frame('ID' = rownames(test_set), 'prob' = predictions$SFARI) %>% mutate(pred = prob>0.5)
  
  
  #############################################################################################################
  # Correct Mean Expression Bias in predictions: model prob ~ s(meanExpr) and
  # keep the residuals, re-centred at the original mean probability
  # NOTE(review): assumes the rows of bias_data align with preds after the
  # right_join -- TODO confirm row order is preserved
  bias_data = data.frame('ID'=rownames(datExpr), 'meanExpr'=rowMeans(datExpr)) %>% right_join(preds, by='ID')
  gam_fit = gam(prob ~ s(meanExpr), method = 'REML', data = bias_data)
  preds$corrected_prob = gam_fit$residuals + mean(preds$prob)
  # BUG FIX: the corrected class prediction must threshold the CORRECTED
  # probability (the original thresholded the uncorrected `prob`, so every
  # per-run metric below evaluated the wrong predictions)
  preds$corrected_pred = preds$corrected_prob>0.5
  #############################################################################################################
  

  # Measure performance of the model on the bias-corrected predictions
  acc = mean(test_set$SFARI==preds$corrected_pred)
  prec = Precision(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  rec = Recall(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  F1 = F1_Score(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  pred_ROCR = prediction(preds$corrected_prob, test_set$SFARI)
  AUC = performance(pred_ROCR, measure='auc')@y.values[[1]]
  
  # Extract coefficients from features at the lambda selected by CV
  coefs = coef(fit$finalModel, fit$bestTune$lambda) %>% as.vector
  
  return(list('acc' = acc, 'prec' = prec, 'rec' = rec, 'F1' = F1, 
              'AUC' = AUC, 'preds' = preds, 'coefs' = coefs))
}


### RUN MODEL

# Parameters
# p: proportion of each split used for training; n_iter: number of runs,
# each with its own seed so splits and CV folds differ between runs
p = 0.75
n_iter = 25
seeds = 123:(123+n_iter-1)

# So the input is the same as in 10_classification_model.html
original_dataset = dataset %>% mutate(ID = rownames(.)) %>% 
                   left_join(old_predictions %>% dplyr::select(ID, gene.score))

# Store outputs
# Per-run metrics are appended to these vectors; per-gene probabilities and
# class predictions are accumulated in `predictions` (n counts how many runs
# put the gene in the test set) and averaged after the loop
acc = c()
prec = c()
rec = c()
F1 = c()
AUC = c()
predictions = data.frame('ID' = rownames(dataset), 'SFARI' = dataset$SFARI, 'prob' = 0, 'pred' = 0,
                         'corrected_prob' = 0, 'corrected_pred' = 0, 'n' = 0)
# One coefficient per feature plus the intercept
# (assumes SFARI is the last column of `dataset` -- TODO confirm)
coefs = data.frame('var' = c('Intercept', colnames(dataset[,-ncol(dataset)])), 'coef' = 0)

# Run the model once per seed, accumulating metrics and per-gene predictions
for(seed in seeds){
  
  # Run model
  model_output = run_model(p, seed)
  
  # Update outputs
  acc = c(acc, model_output[['acc']])
  prec = c(prec, model_output[['prec']])
  rec = c(rec, model_output[['rec']])
  F1 = c(F1, model_output[['F1']])
  AUC = c(AUC, model_output[['AUC']])
  preds = model_output[['preds']]
  coefs$coef = coefs$coef + model_output[['coefs']]
  # n = 1 marks that this gene was in this run's test set; the row-wise
  # addition accumulates probabilities, class predictions and counts
  # (assumes update_preds rows align with the matching `predictions` rows
  # -- TODO confirm row order)
  update_preds = preds %>% dplyr::select(-ID) %>% mutate(n=1)
  predictions[predictions$ID %in% preds$ID, c('prob','pred','corrected_prob','corrected_pred','n')] = 
    predictions[predictions$ID %in% preds$ID, c('prob','pred','corrected_prob','corrected_pred','n')] +
     update_preds
}

# Average coefficients over the runs; average the accumulated probabilities
# per gene, keep the raw prediction counts (`*_count`), and recompute the
# class predictions by thresholding the averaged probabilities at 0.5
coefs = coefs %>% mutate(coef = coef/n_iter)
predictions = predictions %>% mutate(prob = prob/n, pred_count = pred, pred = prob>0.5,
                                     corrected_prob = corrected_prob/n, corrected_pred_count = corrected_pred, 
                                     corrected_pred = corrected_prob>0.5)


rm(p, seeds, update_preds, create_train_test_sets, run_model)
# Genes that appeared in at least one test set, with gene significance (GS)
# and module-diagnosis correlation (MTcor) attached; rownames carry the IDs
test_set = predictions %>% filter(n>0) %>% 
           left_join(dataset %>% mutate(ID = rownames(.)) %>% dplyr::select(ID, GS, MTcor), by = 'ID')
rownames(test_set) = predictions$ID[predictions$n>0]


Performance metrics


Confusion matrix

# Attach display labels for the contingency-table output
conf_mat = test_set %>% apply_labels(SFARI = 'Actual Labels', 
                                     corrected_prob = 'Assigned Probability', 
                                     corrected_pred = 'Label Prediction')

# Confusion matrix: actual SFARI labels vs bias-corrected predictions
cro(conf_mat$SFARI, list(conf_mat$corrected_pred, total()))
 Label Prediction     #Total 
 FALSE   TRUE   
 Actual Labels 
   FALSE  10165 2390   12555
   TRUE  436 204   640
   #Total cases  10601 2594   13195
rm(conf_mat)


Accuracy: Mean = 0.7772 SD = 0.0108


Precision: Mean = 0.0854 SD = 0.0068


Recall: Mean = 0.3702 SD = 0.0334


F1 score: Mean = 0.1387 SD = 0.0109


ROC Curve: Mean = 0.5958 SD = 0.0224

# ROC curve built from the averaged corrected probabilities
pred_ROCR = prediction(test_set$corrected_prob, test_set$SFARI)

roc_ROCR = performance(pred_ROCR, measure='tpr', x.measure='fpr')
auc = performance(pred_ROCR, measure='auc')@y.values[[1]]

# Title reports the mean per-run AUC (not the AUC of the averaged scores)
plot(roc_ROCR, main=paste0('ROC curve (AUC=',round(mean(AUC),2),')'), col='#009999')
abline(a=0, b=1, col='#666666')


Lift Curve

# Lift curve: lift vs rate of positive predictions
lift_ROCR = performance(pred_ROCR, measure='lift', x.measure='rpp')
plot(lift_ROCR, main='Lift curve', col='#86b300')

rm(pred_ROCR, roc_ROCR, AUC, lift_ROCR)




Analyse Results


Score distribution by SFARI Label


SFARI genes have a higher score distribution than the rest, but the overlap is larger than before

# Density of corrected probabilities by SFARI label; dashed vertical lines
# mark each group's mean
plot_data = test_set %>% dplyr::select(corrected_prob, SFARI)

ggplotly(plot_data %>% ggplot(aes(corrected_prob, fill=SFARI, color=SFARI)) + geom_density(alpha=0.3) + 
         geom_vline(xintercept = mean(plot_data$corrected_prob[plot_data$SFARI]), color = '#00C0C2', 
                    linetype='dashed') +
         geom_vline(xintercept = mean(plot_data$corrected_prob[!plot_data$SFARI]), color = '#FF7371', 
                    linetype='dashed') +
        xlab('Score') + ggtitle('Model score distribution by SFARI Label') + theme_minimal())


Score distribution by SFARI Gene Scores


The relation between probability and SFARI Gene Scores weakened but it’s still there

# Corrected probability per gene with its SFARI score category attached
plot_data = test_set %>% mutate(ID=rownames(test_set)) %>% dplyr::select(ID, corrected_prob) %>%
            left_join(original_dataset, by='ID') %>% dplyr::select(ID, corrected_prob, gene.score) %>% 
            apply_labels(gene.score='SFARI Gene score')

# Count of genes in each SFARI score category
cro(plot_data$gene.score)
 #Total 
 SFARI Gene score 
   1  105
   2  168
   3  364
   Neuronal  782
   Others  11762
   #Total cases  13181
# Mean corrected probability per score category
mean_vals = plot_data %>% group_by(gene.score) %>% summarise(mean_prob = mean(corrected_prob))

# Pairwise Welch t-test comparisons between score groups
comparisons = list(c('1','2'), c('2','3'), c('3','Neuronal'), c('Neuronal','Others'),
                   c('1','3'), c('3','Others'), c('2','Neuronal'),
                   c('1','Neuronal'), c('2','Others'), c('1','Others'))
# Vertical placement of the stacked significance brackets
increase = 0.07
base = 0.75
pos_y_comparisons = c(rep(base, 4), rep(base + increase, 2), base + 2:5*increase)

plot_data %>% filter(!is.na(gene.score)) %>% ggplot(aes(gene.score, corrected_prob, fill=gene.score)) + 
              geom_boxplot(outlier.colour='#cccccc', outlier.shape='o', outlier.size=3) +
              stat_compare_means(comparisons = comparisons, label = 'p.signif', method = 't.test', 
                                 method.args = list(var.equal = FALSE), label.y = pos_y_comparisons, 
                                 tip.length = .02) +
              scale_fill_manual(values=SFARI_colour_hue(r=c(1:3,8,7))) + 
              ggtitle('Distribution of probabilities by SFARI score') +
              xlab('SFARI score') + ylab('Probability') + theme_minimal() + theme(legend.position = 'none')

rm(mean_vals, increase, base, pos_y_comparisons)


Genes with the highest Probabilities


  • The concentration of SFARI genes remained the same (1:4)

  • The genes with the highest probabilities are no longer SFARI Genes

# Top 50 genes by corrected probability, annotated with gene symbol, module
# and SFARI score, rendered as a styled table
test_set %>% dplyr::select(corrected_prob, SFARI) %>% mutate(ID = rownames(test_set)) %>% 
             arrange(desc(corrected_prob)) %>% top_n(50, wt=corrected_prob) %>%
             left_join(old_predictions %>% dplyr::select(ID, gene.score, external_gene_id, MTcor, GS), 
                       by = 'ID') %>%
             dplyr::rename('GeneSymbol' = external_gene_id, 'Probability' = corrected_prob, 
                           'ModuleDiagnosis_corr' = MTcor, 'GeneSignificance' = GS) %>%
             mutate(ModuleDiagnosis_corr = round(ModuleDiagnosis_corr,4), Probability = round(Probability,4), 
                    GeneSignificance = round(GeneSignificance,4)) %>%
             left_join(assigned_module, by = 'ID') %>%
             dplyr::select(GeneSymbol, GeneSignificance, ModuleDiagnosis_corr, Module, Probability,
                           gene.score) %>%
             kable(caption = 'Genes with highest model probabilities from the test set') %>% 
             kable_styling(full_width = F)
Genes with highest model probabilities from the test set
GeneSymbol GeneSignificance ModuleDiagnosis_corr Module Probability gene.score
FA2H -0.2586 -0.2919 #00BFC3 0.8050 Others
SRGAP1 0.3651 0.1211 #FB727C 0.8041 Others
CACNA1D -0.4023 -0.2919 #00BFC3 0.8040 2
IGF1 -0.0617 -0.0683 #FF6B93 0.7838 Others
MYRF -0.3458 -0.2919 #00BFC3 0.7821 Others
PIEZO2 -0.0962 -0.2919 #00BFC3 0.7779 Others
SH3TC2 -0.3946 -0.2919 #00BFC3 0.7646 Others
TRIM59 0.0192 -0.2919 #00BFC3 0.7640 Others
CDH19 -0.1403 -0.2919 #00BFC3 0.7624 Others
LDB3 -0.4007 -0.2919 #00BFC3 0.7619 Others
NCKAP1L -0.3043 -0.4038 #00B4EF 0.7601 Others
DOCK5 -0.3920 -0.2919 #00BFC3 0.7576 Others
LRRC63 0.0787 -0.5722 #00BE71 0.7558 Others
PLD1 -0.1953 -0.2919 #00BFC3 0.7505 Others
GALNT6 0.0231 -0.2919 #00BFC3 0.7484 Others
ATP8B4 -0.0444 -0.5016 #7FAE00 0.7472 Others
RPRD2 0.1194 0.3738 #FC7181 0.7468 Others
HHIP -0.1148 -0.2919 #00BFC3 0.7455 Others
KIAA1217 0.3683 0.5621 #4BB400 0.7454 Others
PRELP 0.2110 0.5125 #89AC00 0.7448 Others
GALNTL6 0.0407 -0.5008 #00BADF 0.7444 Others
GABBR2 0.1204 0.2418 #A68AFF 0.7440 3
PLXNC1 0.5660 0.5125 #89AC00 0.7431 Others
SIAH3 0.2476 0.0860 #96A900 0.7410 Others
TRPC6 0.1888 0.4910 #E28900 0.7386 2
CSF2RA -0.3589 -0.4038 #00B4EF 0.7386 Others
CALB2 0.1147 -0.5008 #00BADF 0.7380 Others
PLXDC2 0.3497 0.5272 #F564E4 0.7359 Others
CMTM5 -0.2384 -0.2919 #00BFC3 0.7351 Others
CPO -0.0223 0.1211 #FB727C 0.7349 Others
ACMSD 0.1016 0.6982 #FF68A0 0.7347 Others
CARNS1 -0.3334 -0.2919 #00BFC3 0.7339 Others
DTD1 -0.0972 -0.2897 #00BCD6 0.7332 Others
RGS8 -0.1846 -0.0683 #FF6B93 0.7318 Others
WSCD2 0.2707 0.3292 #EC823A 0.7278 Others
EVI2A -0.0599 -0.2919 #00BFC3 0.7251 Others
CMTM7 -0.1504 -0.0683 #FF6B93 0.7243 Others
TRDN 0.2554 0.5125 #89AC00 0.7223 Others
TP53I11 -0.1272 0.5125 #89AC00 0.7201 Others
C10orf90 -0.4584 -0.2919 #00BFC3 0.7197 Others
FRMD4B -0.1152 -0.2919 #00BFC3 0.7196 Others
GLRA2 0.2491 0.4910 #E28900 0.7194 3
PLEKHH1 -0.2875 -0.2919 #00BFC3 0.7192 Others
P2RY13 -0.0886 -0.1865 #E08A00 0.7187 Others
PRSS48 0.1500 0.1211 #FB727C 0.7187 Others
NHSL2 0.4562 0.2203 #F17D50 0.7183 Others
TTYH2 -0.1800 -0.2919 #00BFC3 0.7181 Others
HLA-DPA1 -0.1306 -0.1793 #26B700 0.7166 Others
FAM196A 0.1567 -0.0102 #00BDD3 0.7149 Others
CACNA1C -0.1527 0.2418 #A68AFF 0.7142 1





Negative samples distribution


The objective of this model is to identify candidate SFARI genes. For this, we are going to focus on the negative samples (the non-SFARI genes)

# Negative samples: genes not labelled SFARI (candidate pool)
negative_set = test_set %>% filter(!SFARI)

negative_set_table = negative_set %>% apply_labels(corrected_prob = 'Assigned Probability', 
                                                   corrected_pred = 'Label Prediction')

# How many negative genes the corrected model classifies as SFARI-like
cro(negative_set_table$corrected_pred)
 #Total 
 Label Prediction 
   FALSE  10165
   TRUE  2390
   #Total cases  12555

2390 genes are predicted as ASD-related


# Distribution of corrected probabilities in the negative set; the dotted
# line marks the 0.5 classification threshold
negative_set %>% ggplot(aes(corrected_prob)) + geom_density(color='#F8766D', fill='#F8766D', alpha=0.5) +
                 geom_vline(xintercept=0.5, color='#333333', linetype='dotted') + xlab('Probability') +
                 ggtitle('Probability distribution of the Negative samples in the Test Set') + 
                 theme_minimal()


# Top 50 negative genes by corrected probability (the candidate SFARI genes),
# annotated with gene symbol (from biomaRt), module and SFARI score category
negative_set %>% dplyr::select(corrected_prob, SFARI) %>% mutate(ID = rownames(negative_set)) %>% 
                 arrange(desc(corrected_prob)) %>% top_n(50, wt=corrected_prob) %>%
                 left_join(original_dataset %>% dplyr::select(ID, gene.score, MTcor, GS), 
                           by = 'ID') %>%
                 left_join(gene_names, by = c('ID'='ensembl_gene_id')) %>%
                 dplyr::rename('GeneSymbol' = external_gene_id, 'Probability' = corrected_prob, 
                               'ModuleDiagnosis_corr' = MTcor, 'GeneSignificance' = GS) %>%
                 mutate(ModuleDiagnosis_corr = round(ModuleDiagnosis_corr,4), 
                        Probability = round(Probability,4), 
                        GeneSignificance = round(GeneSignificance,4)) %>%
                 left_join(assigned_module, by = 'ID') %>%
                 dplyr::select(GeneSymbol, GeneSignificance, ModuleDiagnosis_corr, Module, Probability,
                               gene.score) %>%
                 kable(caption = 'Genes with highest model probabilities from the Negative set') %>% 
                 kable_styling(full_width = F)
Genes with highest model probabilities from the Negative set
GeneSymbol GeneSignificance ModuleDiagnosis_corr Module Probability gene.score
FA2H -0.2586 -0.2919 #00BFC3 0.8050 Others
SRGAP1 0.3651 0.1211 #FB727C 0.8041 Others
IGF1 -0.0617 -0.0683 #FF6B93 0.7838 Others
MYRF -0.3458 -0.2919 #00BFC3 0.7821 Others
PIEZO2 -0.0962 -0.2919 #00BFC3 0.7779 Others
SH3TC2 -0.3946 -0.2919 #00BFC3 0.7646 Others
TRIM59 0.0192 -0.2919 #00BFC3 0.7640 Others
CDH19 -0.1403 -0.2919 #00BFC3 0.7624 Others
LDB3 -0.4007 -0.2919 #00BFC3 0.7619 Others
NCKAP1L -0.3043 -0.4038 #00B4EF 0.7601 Others
DOCK5 -0.3920 -0.2919 #00BFC3 0.7576 Others
LRRC63 0.0787 -0.5722 #00BE71 0.7558 Others
PLD1 -0.1953 -0.2919 #00BFC3 0.7505 Others
GALNT6 0.0231 -0.2919 #00BFC3 0.7484 Others
ATP8B4 -0.0444 -0.5016 #7FAE00 0.7472 Others
RPRD2 0.1194 0.3738 #FC7181 0.7468 Others
HHIP -0.1148 -0.2919 #00BFC3 0.7455 Others
KIAA1217 0.3683 0.5621 #4BB400 0.7454 Others
PRELP 0.2110 0.5125 #89AC00 0.7448 Others
GALNTL6 0.0407 -0.5008 #00BADF 0.7444 Others
PLXNC1 0.5660 0.5125 #89AC00 0.7431 Others
SIAH3 0.2476 0.0860 #96A900 0.7410 Others
CSF2RA -0.3589 -0.4038 #00B4EF 0.7386 Others
CALB2 0.1147 -0.5008 #00BADF 0.7380 Others
PLXDC2 0.3497 0.5272 #F564E4 0.7359 Others
CMTM5 -0.2384 -0.2919 #00BFC3 0.7351 Others
CPO -0.0223 0.1211 #FB727C 0.7349 Others
ACMSD 0.1016 0.6982 #FF68A0 0.7347 Others
CARNS1 -0.3334 -0.2919 #00BFC3 0.7339 Others
DTD1 -0.0972 -0.2897 #00BCD6 0.7332 Others
RGS8 -0.1846 -0.0683 #FF6B93 0.7318 Others
WSCD2 0.2707 0.3292 #EC823A 0.7278 Others
EVI2A -0.0599 -0.2919 #00BFC3 0.7251 Others
CMTM7 -0.1504 -0.0683 #FF6B93 0.7243 Others
TRDN 0.2554 0.5125 #89AC00 0.7223 Others
TP53I11 -0.1272 0.5125 #89AC00 0.7201 Others
C10orf90 -0.4584 -0.2919 #00BFC3 0.7197 Others
FRMD4B -0.1152 -0.2919 #00BFC3 0.7196 Others
PLEKHH1 -0.2875 -0.2919 #00BFC3 0.7192 Others
P2RY13 -0.0886 -0.1865 #E08A00 0.7187 Others
PRSS48 0.1500 0.1211 #FB727C 0.7187 Others
NHSL2 0.4562 0.2203 #F17D50 0.7183 Others
TTYH2 -0.1800 -0.2919 #00BFC3 0.7181 Others
HLA-DPA1 -0.1306 -0.1793 #26B700 0.7166 Others
FAM196A 0.1567 -0.0102 #00BDD3 0.7149 Others
ST18 -0.3694 -0.2919 #00BFC3 0.7139 Others
AIF1 -0.0598 -0.2368 #FF61C4 0.7135 Others
GPR34 -0.0705 -0.4038 #00B4EF 0.7128 Others
PTPLAD2 -0.0614 0.6982 #FF68A0 0.7123 Others
SHISA9 -0.0285 0.2203 #F17D50 0.7117 Neuronal




Comparison with the original model’s probabilities:

  • The genes with probabilities close to 0.5 were affected the most as a group

  • More genes got their score increased than decreased but on average, the ones that got it decreased had a bigger change

# Original vs corrected probability per gene, coloured by the size of the
# change; the dashed diagonal marks "no change"
negative_set %>% mutate(diff = abs(prob-corrected_prob)) %>% 
             ggplot(aes(prob, corrected_prob, color = diff)) + geom_point(alpha=0.2) + scale_color_viridis() + 
             geom_abline(slope=1, intercept=0, color='gray', linetype='dashed') + 
             geom_smooth(color='#666666', alpha=0.5, se=TRUE, size=0.5) + coord_fixed() +
             xlab('Original probability') + ylab('Corrected probability') + theme_minimal() + theme(legend.position = 'none')

negative_set_table = negative_set %>% apply_labels(corrected_prob = 'Corrected Probability', 
                                                   corrected_pred = 'Corrected Class Prediction',
                                                   pred = 'Original Class Prediction')

# Cross-tabulate original vs corrected class predictions
cro(negative_set_table$pred, list(negative_set_table$corrected_pred, total()))
 Corrected Class Prediction     #Total 
 FALSE   TRUE   
 Original Class Prediction 
   FALSE  9841 267   10108
   TRUE  324 2123   2447
   #Total cases  10165 2390   12555

95% of the genes maintained their original predicted class

rm(negative_set_table)

Probability and Gene Significance


There is no noticeable difference between the trends

*The transparent version of the trend line is the original trend line

# Corrected probability vs Gene Significance, coloured by module-diagnosis
# correlation; the transparent line is the trend for the original `prob`
negative_set %>% ggplot(aes(corrected_prob, GS, color=MTcor)) + geom_point() + 
                 geom_smooth(method='gam', color='#666666') + ylab('Gene Significance') +
                 geom_line(stat='smooth', method='gam', color='#666666', alpha=0.5, size=1.2, aes(x=prob)) +
                 geom_hline(yintercept=mean(negative_set$GS), color='gray', linetype='dashed') +
                 scale_color_gradientn(colours=c('#F8766D','white','#00BFC4')) + xlab('Corrected Score') +
                 ggtitle('Relation between the Model\'s Corrected Score and Gene Significance') +
                 theme_minimal()

Summarised version of score vs mean expression, plotting by module instead of by gene

The difference in the trend lines between this plot and the one above is that the one above takes all the points into consideration while this considers each module as an observation by itself, so the top one is strongly affected by big modules and the bottom one treats all modules the same

The transparent version of each point and trend lines are the original values and trends before the bias correction

# Per-module summary: mean and sd of original and corrected probabilities,
# with module size n and the sign of the module-diagnosis correlation
plot_data = negative_set %>% mutate(ID = rownames(.)) %>% left_join(assigned_module, by = 'ID') %>%
            group_by(MTcor, Module) %>% summarise(mean = mean(prob), sd = sd(prob),
                                                  new_mean = mean(corrected_prob),
                                                  new_sd = sd(corrected_prob), n = n()) %>%
            mutate(MTcor_sign = ifelse(MTcor>0, 'Positive', 'Negative')) %>% 
            dplyr::select(Module, MTcor, MTcor_sign, mean, new_mean, sd, new_sd, n) %>% distinct()
# Rename the Module column to ID so plotly can show it as the hover id
colnames(plot_data)[1] = 'ID'

# Solid points/lines: corrected scores; transparent: original scores
ggplotly(plot_data %>% ggplot(aes(MTcor, new_mean, size=n, color=MTcor_sign)) + geom_point(aes(id = ID)) + 
         geom_smooth(method='loess', color='gray', se=FALSE) + geom_smooth(method='lm', se=FALSE) + 
         geom_point(aes(y=mean), alpha=0.3) + 
         geom_line(stat='smooth', method='loess', color='gray', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) + 
         geom_line(stat='smooth', method='lm', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) + 
         xlab('Module-Diagnosis correlation') + ylab('Mean Corrected Score by Module') + 
         theme_minimal() + theme(legend.position='none'))


Probability and mean level of expression


Check if correcting by gene also corrected by module: Yes, but not enough to remove the bias completely

mean_and_sd = data.frame(ID=rownames(datExpr), meanExpr=rowMeans(datExpr), sdExpr=apply(datExpr,1,sd))

plot_data = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI]) %>% 
            left_join(mean_and_sd, by='ID') %>% 
            left_join(assigned_module, by='ID')

plot_data2 = plot_data %>% group_by(Module) %>% summarise(meanExpr = mean(meanExpr), meanProb = mean(prob), 
                                                          new_meanProb = mean(corrected_prob), n=n())

# Module-level view: mean expression vs mean corrected score, with the
# uncorrected means replotted as transparent layers for comparison.
# NOTE(review): points are coloured with plot_data2$Module passed directly as
# a colour vector — this assumes Module holds valid colour strings; confirm
# the module encoding upstream.
ggplotly(plot_data2 %>% ggplot(aes(meanExpr, new_meanProb, size=n)) + 
         geom_point(color=plot_data2$Module) + geom_point(color=plot_data2$Module, alpha=0.3, aes(y=meanProb)) + 
         geom_smooth(method='loess', se=TRUE, color='gray', alpha=0.1, size=0.7) + 
         geom_line(stat='smooth', method='loess', se=TRUE, color='gray', alpha=0.4, size=1.2, aes(y=meanProb)) +
         xlab('Mean Expression') + ylab('Corrected Probability') +  
         ggtitle('Mean expression vs corrected Model score by Module') +
         theme_minimal() + theme(legend.position='none'))
# Drop intermediates; plot_data is reused (reassigned) downstream, so keep it
rm(plot_data2, mean_and_sd)


Probability and LFC


The relation seems to have gotten a bit stronger for both over- and under-expressed genes

# Attach differential-expression results (log fold change, adjusted p-value)
# to the negative-set scores.
# BUG FIX: the original had by='ID' INSIDE mutate(), which created a literal
# column named 'by' and left left_join() without an explicit key, so it fell
# back to joining on all shared columns; the key now goes to left_join().
# NOTE(review): assumes negative_set rows align with the non-SFARI rows of
# test_set — confirm upstream order is preserved.
plot_data = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI]) %>% 
            left_join(DE_info %>% data.frame %>% mutate(ID=rownames(datExpr)), by='ID') %>%
            dplyr::rename('log2FoldChange'= logFC, 'padj' = adj.P.Val)

# LFC vs corrected score per gene; the transparent gray line replots the
# trend of the original (uncorrected) score for comparison
plot_data %>% ggplot(aes(log2FoldChange, corrected_prob)) + geom_point(alpha=0.1, color='#0099cc') + 
              geom_smooth(method='loess', color='gray', alpha=0.1) + 
              geom_line(stat='smooth', method='loess', color='gray', alpha=0.4, size=1.5, aes(y=prob)) +
              xlab('LFC') + ylab('Corrected Probability') +
              theme_minimal() + ggtitle('LFC vs model probability by gene')


Probability and Module-Diagnosis correlation


Not much change

# Gene-level scores joined with module assignment and an integer rank of the
# modules ordered by their Module-Diagnosis correlation (MTcor).
# NOTE(review): assumes negative_set rows align with the non-SFARI rows of
# test_set — confirm upstream construction preserves this order.
module_score = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI]) %>%
               left_join(old_predictions %>% dplyr::select(ID, gene.score), by='ID') %>%
               left_join(assigned_module, by = 'ID') %>%
               dplyr::select(ID, prob, corrected_prob, Module, MTcor) %>% 
               # arrange(MTcor) replaces the misleading arrange(by=MTcor) (the
               # 'by' name was silently ignored) and row_number() replaces the
               # hand-built 1:length(unique(...)) index
               left_join(data.frame(MTcor=unique(dataset$MTcor)) %>% arrange(MTcor) %>% 
                         mutate(order=row_number()), by='MTcor')

# Gene-level corrected score vs Module-Diagnosis correlation; the solid gray
# smooth is the corrected score, the transparent one the original score, and
# the dotted line marks the overall mean corrected score.
# NOTE(review): points are coloured with module_score$Module passed directly
# as a colour vector — assumes Module holds valid colour strings; alpha is
# mapped to corrected_prob^4 to fade low-scoring genes.
ggplotly(module_score %>% ggplot(aes(MTcor, corrected_prob)) + 
         geom_point(color=module_score$Module, aes(id=ID, alpha=corrected_prob^4)) +
         geom_hline(yintercept=mean(module_score$corrected_prob), color='gray', linetype='dotted') + 
         geom_line(stat='smooth', method = 'loess', color='gray', alpha=0.5, size=1.5, aes(x=MTcor, y=prob)) +
         geom_smooth(color='gray', method = 'loess', se = FALSE, alpha=0.3) + theme_minimal() + 
         xlab('Module-Diagnosis correlation') + ylab('Corrected Score'))



Conclusion


This bias correction seems to work partially but not entirely; it does not make a big difference to the performance of the model


Saving results

# Persist the test set (with bias-corrected scores) for downstream analyses;
# row names are kept because they carry the gene IDs
write.csv(test_set, file='./../Data/RM_post_proc_bias_correction.csv', row.names = TRUE)




Session info

# Record R version and package versions for reproducibility
sessionInfo()
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 18.04.4 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
## 
## locale:
##  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
##  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
##  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] grid      stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] DMwR_0.4.1         expss_0.10.2       kableExtra_1.1.0   biomaRt_2.40.5    
##  [5] MLmetrics_1.1.1    car_3.0-7          carData_3.0-3      ROCR_1.0-7        
##  [9] gplots_3.0.3       caret_6.0-86       lattice_0.20-41    mgcv_1.8-31       
## [13] nlme_3.1-147       ggpubr_0.2.5       magrittr_1.5       RColorBrewer_1.1-2
## [17] gridExtra_2.3      viridis_0.5.1      viridisLite_0.3.0  plotly_4.9.2      
## [21] knitr_1.28         forcats_0.5.0      stringr_1.4.0      dplyr_1.0.0       
## [25] purrr_0.3.4        readr_1.3.1        tidyr_1.1.0        tibble_3.0.1      
## [29] ggplot2_3.3.2      tidyverse_1.3.0   
## 
## loaded via a namespace (and not attached):
##   [1] readxl_1.3.1         backports_1.1.8      plyr_1.8.6          
##   [4] lazyeval_0.2.2       splines_3.6.3        crosstalk_1.1.0.1   
##   [7] digest_0.6.25        foreach_1.5.0        htmltools_0.4.0     
##  [10] gdata_2.18.0         fansi_0.4.1          checkmate_2.0.0     
##  [13] memoise_1.1.0        openxlsx_4.1.4       recipes_0.1.10      
##  [16] modelr_0.1.6         gower_0.2.1          matrixStats_0.56.0  
##  [19] xts_0.12-0           prettyunits_1.1.1    colorspace_1.4-1    
##  [22] blob_1.2.1           rvest_0.3.5          haven_2.2.0         
##  [25] xfun_0.12            crayon_1.3.4         RCurl_1.98-1.2      
##  [28] jsonlite_1.7.0       zoo_1.8-8            survival_3.1-12     
##  [31] iterators_1.0.12     glue_1.4.1           gtable_0.3.0        
##  [34] ipred_0.9-9          webshot_0.5.2        shape_1.4.4         
##  [37] quantmod_0.4.17      BiocGenerics_0.30.0  abind_1.4-5         
##  [40] scales_1.1.1         DBI_1.1.0            Rcpp_1.0.4.6        
##  [43] progress_1.2.2       htmlTable_1.13.3     foreign_0.8-76      
##  [46] bit_1.1-15.2         stats4_3.6.3         lava_1.6.7          
##  [49] prodlim_2019.11.13   glmnet_3.0-2         htmlwidgets_1.5.1   
##  [52] httr_1.4.1           ellipsis_0.3.1       pkgconfig_2.0.3     
##  [55] XML_3.99-0.3         farver_2.0.3         nnet_7.3-14         
##  [58] dbplyr_1.4.2         tidyselect_1.1.0     labeling_0.3        
##  [61] rlang_0.4.6          reshape2_1.4.4       AnnotationDbi_1.46.1
##  [64] munsell_0.5.0        cellranger_1.1.0     tools_3.6.3         
##  [67] cli_2.0.2            generics_0.0.2       RSQLite_2.2.0       
##  [70] broom_0.5.5          evaluate_0.14        yaml_2.2.1          
##  [73] ModelMetrics_1.2.2.2 bit64_0.9-7          fs_1.4.0            
##  [76] zip_2.0.4            caTools_1.18.0       xml2_1.2.5          
##  [79] compiler_3.6.3       rstudioapi_0.11      curl_4.3            
##  [82] ggsignif_0.6.0       reprex_0.3.0         stringi_1.4.6       
##  [85] highr_0.8            Matrix_1.2-18        vctrs_0.3.1         
##  [88] pillar_1.4.4         lifecycle_0.2.0      data.table_1.12.8   
##  [91] bitops_1.0-6         R6_2.4.1             KernSmooth_2.23-17  
##  [94] rio_0.5.16           IRanges_2.18.3       codetools_0.2-16    
##  [97] MASS_7.3-51.6        gtools_3.8.2         assertthat_0.2.1    
## [100] withr_2.2.0          S4Vectors_0.22.1     parallel_3.6.3      
## [103] hms_0.5.3            rpart_4.1-15         timeDate_3043.102   
## [106] class_7.3-17         rmarkdown_2.1        TTR_0.23-6          
## [109] pROC_1.16.2          Biobase_2.44.0       lubridate_1.7.4